import pandas as pd
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from matplotlib import rcParams

# 设置字体为 SimHei（黑体），确保中文显示正常
rcParams['font.sans-serif'] = ['SimHei']  # 使用黑体
rcParams['axes.unicode_minus'] = False  # 解决负号显示问题


def plot_feature_importance_rf(xlsx_file, select_feature, target_feature, output_file="plot_feature_importance_rf.png"):
    """
    功能：使用随机森林模型训练数据，计算特征重要性，并绘制特征重要性图。

    参数：
        - xlsx_file: 数据文件路径，支持 .xlsx 文件。
        - select_feature: 用于训练的特征列名称列表。
        - target_feature: 目标列名称（即要预测的列）。
        - output_file: 输出图像的文件路径，默认为 "plot_feature_importance_rf.png"。

    返回：
        - 无返回值，但会生成特征重要性图并保存到指定路径。
    """

    # 1. 加载数据集
    df = pd.read_excel(xlsx_file)
    X = df[select_feature]  # 特征数据
    y = df[target_feature]  # 目标标签

    # 2. 拆分数据集为训练集和测试集
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

    # 3. 初始化并训练随机森林回归模型
    model = RandomForestRegressor(n_estimators=100, random_state=42)
    model.fit(X_train, y_train)

    # 4. 获取特征重要性
    importance = model.feature_importances_

    # 5. 将特征重要性按值排序，并打印前5个最重要的特征
    sorted_importance = sorted(zip(select_feature, importance), key=lambda x: x[1], reverse=True)
    for index, value in enumerate(sorted_importance[:5]):
        print(f"排名第{index + 1}的特征: {value[0]}, 重要性分数: {value[1]:.4f}")

    # 6. 绘制特征重要性图
    plt.figure(figsize=(10, 6))
    plt.barh([x[0] for x in sorted_importance], [x[1] for x in sorted_importance], color='skyblue')
    plt.xlabel("特征重要性分数")
    plt.ylabel("特征名称")
    plt.title("随机森林特征重要性")
    plt.gca().invert_yaxis()  # 将特征按重要性从高到低排列
    plt.tight_layout()  # 调整布局，确保标签显示完整

    # 7. 保存图像到指定路径
    plt.savefig(output_file)
    print(f"特征重要性图已保存为: {output_file}")
    plt.show()  # 显示图像


# 调用函数，绘制特征重要性图
plot_feature_importance_rf(
    "抗压强度.xlsx",  # 数据文件路径
    ["水用量（kg/m3）", "水泥ID", "水泥用量（kg/m3）", "粉煤灰用量（kg/m3）", "砂ID", "砂用量（kg/m3）",
     "石ID", "石用量（kg/m3）", "减水剂ID", "减水剂掺量（%）", "增效剂ID", "增效剂掺量（%）"],  # 特征列
    "7d抗压（MPa）"  # 目标列
)